import torch
 
# Export a trained model checkpoint to a TorchScript file via tracing.

# Pick the GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

CHECKPOINT_PATH = "checkpoints/best_0.8487430065870285.pth"
OUTPUT_PATH = 'unet.pt'
# Shape of the example input used for tracing; presumably
# (batch, channels, height, width) -- TODO confirm against the model.
INPUT_SHAPE = (1, 3, 360, 640)

# Load the model: the checkpoint holds a whole pickled module (it is used
# directly via .eval() below), not just a state_dict.
# weights_only=False is required for that on PyTorch >= 2.6, where the default
# became True. NOTE: this unpickles arbitrary Python objects -- only load
# checkpoints from a trusted source.
model = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
model.eval()
model.to(device)

# Create the example input on the same device as the model so tracing also
# works on CPU-only machines (a hard-coded 'cuda:0' would fail there).
dummy_input = torch.ones(INPUT_SHAPE, device=device)

# Trace the model and save the TorchScript artifact. Autograd is disabled
# because tracing only needs a forward pass.
with torch.no_grad():
    trace_model = torch.jit.trace(model, dummy_input)
trace_model.save(OUTPUT_PATH)